# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.components.retrievers.types import TextRetriever
from haystack.core.serialization import component_to_dict
from haystack.utils.deserialization import deserialize_component_inplace


@component
class MultiQueryTextRetriever:
    """
    A component that retrieves documents using multiple queries in parallel with a text-based retriever.

    This component takes a list of text queries and uses a text-based retriever to find relevant documents for each
    query in parallel, using a thread pool to manage concurrent execution. The results are merged, deduplicated, and
    sorted by relevance score.

    Note that the same retriever instance is invoked concurrently from up to `max_workers` threads.

    You can use this component in combination with the QueryExpander component to enhance the retrieval process.

    ### Usage example
    ```python
    from haystack import Document
    from haystack.components.writers import DocumentWriter
    from haystack.document_stores.in_memory import InMemoryDocumentStore
    from haystack.document_stores.types import DuplicatePolicy
    from haystack.components.retrievers import InMemoryBM25Retriever
    from haystack.components.query import QueryExpander
    from haystack.components.retrievers.multi_query_text_retriever import MultiQueryTextRetriever

    documents = [
        Document(content="Renewable energy is energy that is collected from renewable resources."),
        Document(content="Solar energy is a type of green energy that is harnessed from the sun."),
        Document(content="Wind energy is another type of green energy that is generated by wind turbines."),
        Document(content="Hydropower is a form of renewable energy using the flow of water to generate electricity."),
        Document(content="Geothermal energy is heat that comes from the sub-surface of the earth.")
    ]

    document_store = InMemoryDocumentStore()
    doc_writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)
    doc_writer.run(documents=documents)

    in_memory_retriever = InMemoryBM25Retriever(document_store=document_store, top_k=1)
    multiquery_retriever = MultiQueryTextRetriever(retriever=in_memory_retriever)
    results = multiquery_retriever.run(queries=["renewable energy?", "Geothermal", "Hydropower"])
    for doc in results["documents"]:
        print(f"Content: {doc.content}, Score: {doc.score}")
    >>
    >> Content: Geothermal energy is heat that comes from the sub-surface of the earth., Score: 1.6474448833731097
    >> Content: Hydropower is a form of renewable energy using the flow of water to generate electricity., Score: 1.615
    >> Content: Renewable energy is energy that is collected from renewable resources., Score: 1.5255309812344944
    ```
    """  # noqa E501

    def __init__(self, *, retriever: TextRetriever, max_workers: int = 3) -> None:
        """
        Initialize MultiQueryTextRetriever.

        :param retriever: The text-based retriever to use for document retrieval.
        :param max_workers: Maximum number of worker threads for parallel processing. Default is 3.
        """
        self.retriever = retriever
        self.max_workers = max_workers
        self._is_warmed_up = False

    def warm_up(self) -> None:
        """
        Warm up the wrapped retriever if it exposes a callable `warm_up` method.

        Only the first call has an effect; later calls return immediately.
        """
        if self._is_warmed_up:
            return
        retriever_warm_up = getattr(self.retriever, "warm_up", None)
        if callable(retriever_warm_up):
            retriever_warm_up()
        self._is_warmed_up = True

    @component.output_types(documents=list[Document])
    def run(self, queries: list[str], retriever_kwargs: Optional[dict[str, Any]] = None) -> dict[str, list[Document]]:
        """
        Retrieve documents using multiple queries in parallel.

        A document returned for more than one query appears only once in the output. Documents with text content
        are considered duplicates when their `content` is equal. Documents without content are deduplicated by `id`
        instead, so they are not all collapsed into one. Among duplicates, the copy with the highest score is kept.

        :param queries: List of text queries to process.
        :param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method for every
            query. It must not contain a `query` key, because the query is supplied separately for each call.
        :returns:
            A dictionary containing:
                `documents`: List of retrieved documents sorted by relevance score, highest first. A document
                without a score is ranked as if its score were 0.0.
        """
        retriever_kwargs = retriever_kwargs or {}
        # Dedup key -> best-scoring document seen so far. Replacing a value keeps the key's original insertion
        # position, so score ties are still broken by first-seen order in the stable sort below.
        unique_docs: dict[tuple[str, Any], Document] = {}

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # executor.map yields results in the order of `queries`, whatever order the threads finish in.
            queries_results = executor.map(lambda query: self._run_on_thread(query, retriever_kwargs), queries)
            for result in queries_results:
                if not result:
                    continue
                for doc in result:
                    key = self._dedup_key(doc)
                    kept = unique_docs.get(key)
                    if kept is None or self._rank_score(doc) > self._rank_score(kept):
                        unique_docs[key] = doc

        docs = list(unique_docs.values())
        docs.sort(key=self._rank_score, reverse=True)
        return {"documents": docs}

    @staticmethod
    def _dedup_key(doc: Document) -> tuple[str, Any]:
        """
        Return the key used to detect the same document returned by different queries.

        :param doc: A retrieved document.
        :returns: `("content", content)` for documents with content, otherwise `("id", id)`.
        """
        # The leading tag keeps content strings and ids in separate key spaces so they can never collide.
        if doc.content is not None:
            return ("content", doc.content)
        return ("id", doc.id)

    @staticmethod
    def _rank_score(doc: Document) -> float:
        """
        Return the score used for ranking and for choosing between duplicates.

        :param doc: A retrieved document.
        :returns: The document's score, or 0.0 when it has none.
        """
        return doc.score or 0.0

    def _run_on_thread(self, query: str, retriever_kwargs: Optional[dict[str, Any]] = None) -> Optional[list[Document]]:
        """
        Process a single query on a separate thread.

        :param query: The text query to process.
        :param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method.
        :returns:
            List of retrieved documents or None if the retriever returned no `documents` entry.
        """
        result = self.retriever.run(query=query, **(retriever_kwargs or {}))
        if result and "documents" in result:
            return result["documents"]
        return None

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            The serialized component as a dictionary, with the wrapped retriever serialized under `retriever`.
        """
        return default_to_dict(
            self, retriever=component_to_dict(obj=self.retriever, name="retriever"), max_workers=self.max_workers
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MultiQueryTextRetriever":
        """
        Deserializes the component from a dictionary.

        :param data: The dictionary to deserialize from. Its `init_parameters["retriever"]` entry is replaced by the
            deserialized retriever component before the instance is built.
        :returns:
            The deserialized component.
        """
        deserialize_component_inplace(data["init_parameters"], key="retriever")
        return default_from_dict(cls, data)
